!! Copyright (C) 2009,2010,2011,2012  Marco Restelli
!!
!! This file is part of:
!!   FEMilaro -- Finite Element Method toolkit
!!
!! FEMilaro is free software; you can redistribute it and/or modify it
!! under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 3 of the License, or
!! (at your option) any later version.
!!
!! FEMilaro is distributed in the hope that it will be useful, but
!! WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
!! General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with FEMilaro; If not, see <http://www.gnu.org/licenses/>.
!!
!! author: Marco Restelli                   <marco.restelli@gmail.com>

!>\brief
!! Linear system for the ldgh code.
!!
!! \n
!!
!! This module contains the definition of the linear system for the
!! ldgh solver, including the declarations of the related variables.
!! The linear solvers are also handled here.
!<----------------------------------------------------------------------
module mod_ldgh_linpb

!-----------------------------------------------------------------------

 use mod_utils, only: &
   t_realtime, my_second

 use mod_messages, only: &
   mod_messages_initialized, &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   mod_kinds_initialized, &
   wp

 use mod_state_vars, only: &
   mod_state_vars_initialized, &
   c_stv

 use mod_linal, only: &
   mod_linal_initialized, &
   invmat

 use mod_sparse, only: &
   mod_sparse_initialized, &
   t_col!,       &
   !check what is actually needed here
   !t_tri,       &
   !t_pm_sk,     &
   !! overloaded operators
   !transpose,   &
   !clear

 use mod_mpi_utils, only: &
   mod_mpi_utils_initialized, &
   mpi_status_size, mpi_request_null, mpi_integer, &
   mpi_isend, mpi_irecv, mpi_wait, mpi_waitall,    &
   mpi_bcast

 use mod_linsolver, only: &
   mod_linsolver_initialized, &
   c_linpb, c_mumpspb

 use mod_base, only: &
   mod_base_initialized, &
   t_base

 use mod_grid, only: &
   mod_grid_initialized, &
   t_grid, t_ddc_grid

 use mod_bcs, only: &
   mod_bcs_initialized, &
   t_bcs,                     &
   t_b_v,   t_b_s,   t_b_e,   &
   p_t_b_v, p_t_b_s, p_t_b_e, &
   b_dir,   b_neu,   b_ddc

!-----------------------------------------------------------------------
 
 implicit none

!-----------------------------------------------------------------------

! Module interface

 public :: &
   mod_ldgh_linpb_constructor, &
   mod_ldgh_linpb_destructor,  &
   mod_ldgh_linpb_initialized, &
   dofs_dir, dofs_nat, gdofs_nat, &
   t_lambda, lam, lam1, lam2,     &
   fff, fff1, fff2, rhs,       &
   edim, eldim, sdim, nnz_mmm11,  &
   mmmi, mmmj, mmmx, mmm, mmm11, mmm12, mmm21, mmm22, &
   linpb,                      &
   me_g2l

 private

!-----------------------------------------------------------------------

! Module types and parameters

 ! public members

 !> Type for the lagrange multiplier
 !!
 !! Extends the abstract state vector \c c_stv; the type bound
 !! procedures implement the corresponding linear-space operations.
 type, extends(c_stv) :: t_lambda
  !> The degrees of freedom are arranged in a 1D array; the major
  !! ordering comes from the side ordering, the minor one from the
  !! local dofs ordering.
  real(wp), allocatable :: l(:)
 contains
  procedure, pass(x) :: incr => l_incr !< x <- x + y
  procedure, pass(x) :: tims => l_tims !< x <- r*x
  procedure, pass(z) :: copy => l_copy !< z <- x
  procedure, pass(x) :: source => l_source !< allocate y as a copy of x
  procedure, pass(x) :: source_vect => l_source_vect !< rank-1 version
 end type t_lambda

 !> MUMPS solver
 !!
 !! Specialization of the generic MUMPS problem providing the copy of
 !! the MUMPS solution vector into a \c t_lambda (see \c mumps2lambda).
 type, extends(c_mumpspb) :: t_mumpspb
 contains
  procedure, pass(s) :: xassign => mumps2lambda
 end type t_mumpspb

 ! private members

! Module variables

 ! local matrices corrections
 !> me_g2l(:,:,is): change of basis matrix M_{e,g->l} for the side
 !! permutation is (built in the constructor, step 4.1)
 real(wp), allocatable :: me_g2l(:,:,:)

 ! global matrix
 ! edim, eldim, sdim, nnz_mmm11: dimensions and number of nonzeros;
 ! presumably filled by the solver code outside this module - confirm
 integer :: edim, eldim, sdim, nnz_mmm11
 ! sparse matrix coordinate indexes / values and rhs work arrays
 ! (allocated outside this module, except rhs)
 integer, allocatable :: mmmi(:), mmmj(:)
 type(t_lambda) :: lam1
 real(wp), allocatable :: mmmx(:), fff(:), fff1(:), fff2(:), &
  lam(:), lam2(:)
 ! rhs is a target because linpb%rhs points to it (see constructor)
 real(wp), allocatable , target :: rhs(:)
 ! global matrix and its blocks; linpb%m points to mmm11
 type(t_col), save, target :: mmm, mmm11, mmm12, mmm21, mmm22

 ! linear problem
 class(c_linpb), allocatable :: linpb

 ! bcs
 ! zero based dof indexes of the Dirichlet and natural dofs; gdofs_nat
 ! holds the global numbering and is a target for linpb%gij
 integer, allocatable :: dofs_dir(:), dofs_nat(:)
 integer, allocatable, target :: gdofs_nat(:)

 ! public members
 logical, protected ::               &
   mod_ldgh_linpb_initialized = .false.
 
 ! private members

 character(len=*), parameter :: &
   this_mod_name = 'mod_ldgh_linpb'

!-----------------------------------------------------------------------

contains

!-----------------------------------------------------------------------

 !> Define all the variables related to the linear system
 !!
 !! This subroutine also defines the global numbering of the degrees
 !! of freedom. More precisely, we need a global numbering for
 !! <em>natural</em> degrees of freedom, since Dirichlet dofs can be
 !! dealt with completely at the local level. Notice that shared
 !! degrees of freedom must be labeled consistently; this is ensured
 !! by using a sequential algorithm to number <em>natural</em> sides,
 !! to which natural degrees of freedom are associated, and by
 !! introducing local conversion matrices.
 !!
 !! The algorithm used to assign a unique global index to natural
 !! sides is similar to those used, for instance, in \c mod_grid. More
 !! interesting is the problem of ensuring that on each shared side
 !! the degrees of freedom are defined consistently. In fact, given
 !! that a shared side receives a unique global index, one is left
 !! with the problem of defining the local degrees of freedom. In the
 !! present code, such local degrees of freedom are defined following
 !! the ordering of the side nodes, which is in general domain
 !! dependent. One possibility would be redefining the local ordering;
 !! however this would go against the general rule that local problems
 !! are independent from the global one. The alternative is to
 !! introduce <em>two</em> sets of local degrees of freedom, and
 !! defining as well a transformation matrix. This is done as follows:
 !! <ul>
 !!  <li> on each side, the subdomain with lower index defines the
 !!  global degrees of freedom \f$\lambda\f$, without further changes;
 !!  <li> on each side, the subdomain with higher index defines local
 !!  degrees of freedom \f$\lambda_l\f$ as well as a change of basis
 !!  matrix \f$M_{e,g\to l}\f$ such that
 !!  \f{displaymath}{
 !!   \lambda_l = M_{e,g\to l}\lambda.
 !!  \f}
 !! </ul>
 !! Then on the second subdomain, all the local matrices which include
 !! the side functions must be multiplied by \f$M_{e,g\to l}\f$ on the
 !! left (test functions) and/or on the right (shape functions), since
 !! \f{displaymath}{
 !!  \mu_l^TA\lambda_l = \mu^T M_{e,g\to l}^T\,A\,M_{e,g\to l}\lambda.
 !! \f}
 !! No other changes are required. Thanks to the fact that the side
 !! basis functions are orthogonal, we have
 !! \f{displaymath}{
 !!  \left(M_{e,g\to l}\right)_{ij} = \int_e
 !!  \eta_{l_i}\eta_j\,d\sigma,
 !! \f}
 !! and, since each \f$\eta_{l_i}\f$ is orthogonal to all the
 !! polynomials of lower degree, the matrix will have a block diagonal
 !! structure and will be orthogonal.
 subroutine mod_ldgh_linpb_constructor( grid,ddc_grid,base,bcs,comm, &
                                        zero_mean )
  type(t_grid),     intent(in) :: grid      !< local (subdomain) grid
  type(t_ddc_grid), intent(in) :: ddc_grid  !< domain decomposition data
  type(t_base),     intent(in) :: base      !< finite element base
  type(t_bcs),      intent(in) :: bcs       !< boundary conditions
  integer,          intent(in) :: comm      !< MPI communicator
  logical,          intent(in) :: zero_mean !< append one extra global
                                            !! dof (zero mean constraint)

  ! is_dir, is_nat: local counters of Dirichlet and natural sides;
  ! gis_nat: running global counter of natural sides
  integer :: i, j, is, is_dir, is_nat, gis_nat, req, ierr
  integer, allocatable :: recvbuf(:,:), reqs(:), mpi_stat(:,:), &
    idx(:), sendbuf(:,:), is_s(:), read_recvbuf(:)
  real(wp) :: tmp
  character(len=*), parameter :: &
    this_sub_name = 'constructor'

   !Consistency checks ---------------------------
   if( (mod_messages_initialized.eqv..false.) .or. &
          (mod_kinds_initialized.eqv..false.) .or. &
     (mod_state_vars_initialized.eqv..false.) .or. &
          (mod_linal_initialized.eqv..false.) .or. &
         (mod_sparse_initialized.eqv..false.) .or. &
      (mod_mpi_utils_initialized.eqv..false.) .or. &
      (mod_linsolver_initialized.eqv..false.) .or. &
           (mod_base_initialized.eqv..false.) .or. &
           (mod_grid_initialized.eqv..false.) .or. &
            (mod_bcs_initialized.eqv..false.) ) then
     call error(this_sub_name,this_mod_name, &
                'Not all the required modules are initialized.')
   endif
   if(mod_ldgh_linpb_initialized.eqv..true.) then
     call warning(this_sub_name,this_mod_name, &
                  'Module is already initialized.')
   endif
   !----------------------------------------------

   ! Allocate module arrays
   ! each Dirichlet boundary side carries base%nk local dofs
   allocate(dofs_dir( base%nk * count(bcs%b_side%bc.eq.b_dir) ))
   ! all the remaining side dofs are natural; with zero_mean one extra
   ! dof is appended at the end (see step 5.1)
   if(.not.zero_mean) then
     allocate(dofs_nat( base%nk*grid%ns - size(dofs_dir) ))
   else
     allocate(dofs_nat( base%nk*grid%ns+1 - size(dofs_dir) ))
   endif
   allocate(gdofs_nat( size(dofs_nat) ))
   allocate(me_g2l(base%nk,base%nk,size(base%stab,1)))
   ! Work arrays for synchronization
   allocate( recvbuf(2,sum(ddc_grid%nns_id(:ddc_grid%id-1))) )
   allocate( reqs    (                0:ddc_grid%nd-1) , &
             mpi_stat(mpi_status_size,0:ddc_grid%nd-1) )

   ! 1) Get information from previous subdomains
   ! get the global counter (tag 1 messages form a chain id-1 -> id,
   ! so that the natural sides are numbered sequentially)
   if(ddc_grid%id.gt.0) then
     call mpi_irecv( gis_nat , 1 , mpi_integer ,          &
                     ddc_grid%id-1, 1 , comm , req , ierr )
   else
     gis_nat = 0 ! set the global counter
   endif
   ! get the shared sides (tag 2 messages, one from each previous
   ! subdomain with nonempty shared-side list)
   reqs = mpi_request_null
   is = 1
   do i=0,ddc_grid%id-1
     if(ddc_grid%nns_id(i).gt.0) &
       call mpi_irecv(recvbuf(1,is),2*ddc_grid%nns_id(i),mpi_integer,&
                      i, 2, comm, reqs(i), ierr )
     is = is + ddc_grid%nns_id(i)
   enddo
   ! wait for the global counter; prepare working arrays for step 2
   allocate( idx(base%nk) ); idx = (/ (i , i=0,base%nk-1) /) ! 0-based offsets
   allocate( sendbuf(2,sum(ddc_grid%nns_id(ddc_grid%id+1:))) )
   allocate( read_recvbuf(ddc_grid%nns) )
   allocate( is_s(ddc_grid%id+1:ddc_grid%nd-1) ) ! shifts in sendbuf
   do i=ddc_grid%id+1,ddc_grid%nd-1
     is_s(i) = sum(ddc_grid%nns_id(ddc_grid%id+1:i-1))
   enddo
   if(ddc_grid%id.gt.0) call mpi_wait( req , mpi_stat(:,0) , ierr )

   ! 2.1) Local subdomain, internal sides
   ! Note: all the dof indexes stored in dofs_dir, dofs_nat and
   ! gdofs_nat are zero based
   is_dir = 0; is_nat = 0 ! side counters: Dirichlet, natural
   int_s_do: do is=1,grid%ni ! internal sides are natural sides
     is_nat = is_nat+1
     dofs_nat( (is_nat-1)*base%nk + idx+1 ) = (is-1)*base%nk + idx
     gis_nat = gis_nat + 1
     gdofs_nat((is_nat-1)*base%nk + idx+1 ) = (gis_nat-1)*base%nk + idx
   enddo int_s_do
   ! 2.2) Local subdomain, boundary sides - must wait for recvbuf
   call mpi_waitall( size(reqs) , reqs , mpi_stat , ierr )
   ! scatter the received pairs: store each received global side index
   ! at the position of the corresponding local side
   read_recvbuf( ddc_grid%gs(recvbuf(1,:))%ni ) = recvbuf(2,:)
   deallocate(recvbuf)
   bdr_s_do: do is=grid%ni+1,grid%ns
     b_type: select case(bcs%b_s2bs(is)%p%bc)
      case(b_neu) ! same as internal sides
       is_nat = is_nat+1
       dofs_nat( (is_nat-1)*base%nk + idx+1 ) = (is-1)*base%nk + idx
       gis_nat = gis_nat + 1
       gdofs_nat((is_nat-1)*base%nk + idx+1) = (gis_nat-1)*base%nk + idx
      case(b_ddc)
       ! locally this is a nat side
       is_nat = is_nat+1
       dofs_nat( (is_nat-1)*base%nk + idx+1 ) = (is-1)*base%nk + idx
       ! the global numbering requires synchronization
       i = ddc_grid%ns( ddc_grid%gs(is)%ni )%id ! neighbouring subdomain
       if(i.lt.ddc_grid%id) then ! read the relevant information
         j = read_recvbuf( ddc_grid%gs(is)%ni ) ! global side index
         gdofs_nat((is_nat-1)*base%nk + idx+1) = (j-1)*base%nk + idx
       else ! define new global degrees of freedom
         gis_nat = gis_nat + 1
         gdofs_nat((is_nat-1)*base%nk + idx+1) = (gis_nat-1)*base%nk+idx
         ! other subdomains must be informed
         is_s(i) = is_s(i) + 1
         sendbuf(1,is_s(i)) = ddc_grid%ns( ddc_grid%gs(is)%ni )%in
         sendbuf(2,is_s(i)) = gis_nat
       endif
      case(b_dir)
       is_dir = is_dir+1
       dofs_dir( (is_dir-1)*base%nk + idx+1 ) = (is-1)*base%nk + idx
     end select b_type
   enddo bdr_s_do
   ! deallocations afterwards

   ! 3) Send information to subsequent subdomains
   ! pass the updated global counter along the chain (tag 1)
   if(ddc_grid%id.lt.ddc_grid%nd-1) &
     call mpi_isend( gis_nat , 1 , mpi_integer ,          &
                     ddc_grid%id+1, 1 , comm , req , ierr )
   ! send the newly defined global side indexes (tag 2)
   reqs = mpi_request_null
   is = 1
   do i=ddc_grid%id+1,ddc_grid%nd-1
     if(ddc_grid%nns_id(i).gt.0) &
       call mpi_isend(sendbuf(1,is),2*ddc_grid%nns_id(i),mpi_integer,&
                      i, 2, comm, reqs(i) , ierr )
     is = is + ddc_grid%nns_id(i)
   enddo
   ! waiting and deallocations afterwards

   ! 4.1) Side transformation matrices
   ! (M_{e,g->l})_{ij} = sum_q w_q eta_{l_i}(x_q) eta_j(x_q) / ||eta_j||^2,
   ! with the quadrature nodes permuted according to base%stab (see the
   ! header comment of this subroutine for the definition)
   me_g2l = 0.0_wp
   do j=1,base%nk
     tmp = sum(base%wgs * base%e(j,:)**2) ! squared norm
     do is=1,size(base%stab,1)
       do i=1,base%nk
         me_g2l(i,j,is) = sum( base%wgs(base%stab(is,:))     &
             * base%e(i,base%stab(is,:)) * base%e(j,:) ) / tmp
       enddo
     enddo
   enddo
   ! 4.2) Wait and deallocations
   if(ddc_grid%id.lt.ddc_grid%nd-1) &
     call mpi_wait( req , mpi_stat(:,ddc_grid%id+1) , ierr )
   call mpi_waitall( size(reqs) , reqs , mpi_stat , ierr )
   deallocate(sendbuf,reqs,mpi_stat,idx,is_s,read_recvbuf)

   ! 5) The last subdomain communicates the total number of nat sides
   call mpi_bcast(gis_nat,1,mpi_integer,ddc_grid%nd-1,comm,ierr)
   ! 5.1) Add the zero mean degree of freedom if necessary
   ! (appended as the last natural dof)
   if(zero_mean) then
     dofs_nat (size(dofs_nat)) = base%nk*grid%ns ! zero based indexes
     gdofs_nat(size(dofs_nat)) = base%nk*gis_nat
   endif

   ! 6) Now we can define the linear problem

   ! Note: pointers can be pointed to allocatable objects only when
   ! such objects are already allocated (and pointers would lose
   ! their association status once the target is deallocated). So, we
   ! have to anticipate here some allocations.
   allocate(rhs(size(dofs_nat)))

   select case('mumps') ! make this more general if you want
    case('mumps')
     allocate(t_mumpspb::linpb)
     select type(linpb); type is(t_mumpspb)
      linpb%distributed = .true.
      ! global system size: all the natural dofs (plus the optional
      ! zero mean constraint dof)
      linpb%gn  = base%nk * gis_nat
      if(zero_mean) linpb%gn = linpb%gn + 1
      linpb%m   => mmm11
      linpb%rhs => rhs
      linpb%gij => gdofs_nat
      linpb%transposed_mat = .false.
      linpb%mpi_comm = comm
     end select
    case default
     call error(this_sub_name,this_mod_name, &
                'Unknown linear solver.')
   end select

   mod_ldgh_linpb_initialized = .true.
 end subroutine mod_ldgh_linpb_constructor

!-----------------------------------------------------------------------
 
 !> Release all the module allocations and reset the module status.
 subroutine mod_ldgh_linpb_destructor()
  character(len=*), parameter :: &
    this_sub_name = 'destructor'

   !Consistency checks ---------------------------
   if(.not.mod_ldgh_linpb_initialized) then
     call error(this_sub_name,this_mod_name, &
                'This module is not initialized.')
   endif
   !----------------------------------------------

   ! free the linear problem and all the module arrays allocated by
   ! the constructor
   deallocate( linpb , rhs , me_g2l , &
               dofs_dir , dofs_nat , gdofs_nat )

   mod_ldgh_linpb_initialized = .false.
 end subroutine mod_ldgh_linpb_destructor

!-----------------------------------------------------------------------
 
 !> Increment: x <- x + y (no effect if y is not a t_lambda)
 subroutine l_incr(x,y)
  class(t_lambda), intent(inout) :: x
  class(c_stv),    intent(in)    :: y

   select type(y)
   type is(t_lambda)
     x%l = x%l + y%l
   end select
 end subroutine l_incr

!-----------------------------------------------------------------------

 !> Scaling: x <- r*x
 subroutine l_tims(x,r)
  class(t_lambda), intent(inout) :: x
  real(wp),        intent(in)    :: r

   x%l = x%l * r
 end subroutine l_tims

!-----------------------------------------------------------------------

 !> Assignment: z <- x (no effect if x is not a t_lambda)
 subroutine l_copy(z,x)
  class(t_lambda), intent(inout) :: z
  class(c_stv),    intent(in)    :: x

   select type(x)
   type is(t_lambda)
     z%l = x%l
   end select
 end subroutine l_copy

!-----------------------------------------------------------------------

 !> Allocate y as a t_lambda copy of x.
 !! Workaround for ifort compiler bug (see \c mod_state_vars)
 subroutine l_source(y,x)
  class(t_lambda), intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y

   allocate(t_lambda::y)
   select type(y)
   type is(t_lambda)
     ! sourced allocation: y%l gets the shape and the values of x%l
     allocate( y%l , source=x%l )
   end select
 end subroutine l_source
 !> Rank-1 variant of \c l_source: allocate y(1:m), each element being
 !! a t_lambda copy of x.
 subroutine l_source_vect(y,x,m)
  integer, intent(in) :: m
  class(t_lambda), intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y(:)

  integer :: k

   allocate(t_lambda::y(m))
   select type(y)
   type is(t_lambda)
     do k=1,m
       ! sourced allocation: shape and values taken from x%l
       allocate( y(k)%l , source=x%l )
     enddo
   end select
 end subroutine l_source_vect

!-----------------------------------------------------------------------

 !> Copy the MUMPS solution vector into the t_lambda state x
 !! (no effect if x is not a t_lambda).
 subroutine mumps2lambda(x,s,mumps_x)
  class(c_stv),     intent(inout) :: x
  class(t_mumpspb), intent(inout) :: s ! passed-object dummy, unused here
  real(wp),         intent(in) :: mumps_x(:)

   select type(x)
   type is(t_lambda)
     x%l = mumps_x
   end select
 end subroutine mumps2lambda

!-----------------------------------------------------------------------

end module mod_ldgh_linpb

